Simple Examples

This tutorial goes through a few common ML tasks using the CREMI dataset and a 2D U-Net.

Introduction and overview

In this tutorial we will cover a few basic ML tasks using the DaCapo toolbox. We will:

  • Prepare a dataloader for the CREMI dataset

  • Train a simple 2D U-Net for both instance and semantic segmentation

  • Visualize the results

Environment setup

If you have not already done so, you will need to install DaCapo. You can do this by first creating a new environment and then installing the DaCapo Toolbox.

I highly recommend using uv for environment management, but there are many tools to choose from.

uv init
uv add git+https://github.com/pattonw/dacapo-toolbox.git

Data Preparation

DaCapo works with zarr, so we will download CREMI Samples A and C (A is used for training, C for testing) and save them as a zarr file.

import wget
from pathlib import Path
import dask

dask.config.set(scheduler="single-threaded")

# Download the CREMI samples used in this tutorial (sample C is later used
# as test data, sample A as training data). Each file is checked on its own
# so that a missing sample is not masked by the presence of the other one.
# The conversion to zarr happens in the "Data Loading" section below.
for sample_file in ("sample_C_20160501.hdf", "sample_A_20160501.hdf"):
    if not Path(sample_file).exists():
        wget.download(f"https://cremi.org/static/data/{sample_file}", sample_file)

Data Loading

We will use the funlib.persistence library to interface with zarr. This library adds support for units, voxel size, and axis names along with the ability to query our data based on a Roi object describing a specific rectangular piece of data. This is especially useful in a microscopy context where you regularly need to chunk your data for processing.

import numpy as np
from funlib.persistence import prepare_ds, open_ds
import h5py
from pathlib import Path
import re
def _convert_split(hdf_path, split):
    """Copy one CREMI HDF file into ``cremi.zarr/<split>/raw`` and ``.../labels``.

    Args:
        hdf_path: path of the downloaded CREMI ``.hdf`` file.
        split: name of the zarr group to write (``"train"`` or ``"test"``).

    Returns:
        The ``(raw, labels)`` arrays returned by ``prepare_ds``.
    """
    # The context manager guarantees the HDF5 handle is closed as soon as
    # both volumes have been read into memory.
    with h5py.File(hdf_path, "r") as hdf:
        volumes = {
            "raw": hdf["volumes/raw"][:],
            "labels": hdf["volumes/labels/neuron_ids"][:],
        }

    arrays = []
    for name, data in volumes.items():
        array = prepare_ds(
            f"cremi.zarr/{split}/{name}",
            data.shape,
            voxel_size=(40, 4, 4),
            dtype=data.dtype,
            axis_names=["z", "y", "x"],
            units=["nm", "nm", "nm"],
        )
        # write the whole volume into the freshly prepared dataset
        array[array.roi] = data
        arrays.append(array)
    return tuple(arrays)


def _load_split(hdf_path, split):
    """Return ``(raw, labels)`` for ``split``, converting from HDF only if needed.

    The split is converted when either of its two arrays is missing, so a
    partially populated ``cremi.zarr`` no longer makes ``open_ds`` fail.
    """
    raw_path = f"cremi.zarr/{split}/raw"
    labels_path = f"cremi.zarr/{split}/labels"
    if Path(raw_path).exists() and Path(labels_path).exists():
        return open_ds(raw_path), open_ds(labels_path)
    return _convert_split(hdf_path, split)


# Sample C provides the test data, sample A the training data.
test_raw, test_labels = _load_split("sample_C_20160501.hdf", "test")
train_raw, train_labels = _load_split("sample_A_20160501.hdf", "train")

Let's visualize our training and testing data.

# a custom label color map for showing instances
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.animation as animation
from IPython.display import HTML
import matplotlib as mpl

# NOTE: ``animation.embed_limit`` is expressed in MB, so this value is far
# larger than "50 MB"; in practice it removes the size limit on embedded
# animations for this notebook.
mpl.rcParams["animation.embed_limit"] = 50_000_000

# Create a custom label color map for showing instances.
# Entry 0 is black; the remaining 255 entries are random colors drawn with a
# fixed seed. Matplotlib expects RGB components as floats in [0, 1], so the
# random 0-255 integers are rescaled before the colormap is built.
np.random.seed(1)
colors = [[0.0, 0.0, 0.0]] + [
    list(np.random.choice(range(256), size=3) / 255) for _ in range(255)
]
label_cmap = ListedColormap(colors)

Training data

# Animate the training volume section by section along its first axis:
# raw data on the left, instance labels (folded into 256 colors) on the right.
fig, (raw_ax, label_ax) = plt.subplots(1, 2, figsize=(12, 6))

frames = []
for index, (raw_slice, label_slice) in enumerate(
    zip(train_raw.data, train_labels.data)
):
    first_frame = index == 0
    # every frame after the first one is flagged as animated
    extra = {} if first_frame else {"animated": True}
    raw_im = raw_ax.imshow(raw_slice, **extra)
    label_im = label_ax.imshow(
        label_slice % 256,
        cmap=label_cmap,
        vmin=0,
        vmax=255,
        interpolation="none",
        **extra,
    )
    if first_frame:
        raw_ax.set_title("Raw Train Data")
        label_ax.set_title("Train Labels")
    frames.append([raw_im, label_im])

# play the stack forwards and then backwards
frames = [*frames, *reversed(frames)]
ani = animation.ArtistAnimation(fig, frames, blit=True, repeat_delay=1000)
# force a fixed display width on the generated <video> tag
video_html = re.sub(r"<video ", '<video width="800" ', ani.to_html5_video())
HTML(video_html)
WARNING:matplotlib.animation:MovieWriter stderr:
Received > 3 system signals, hard exiting
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:224, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
    223 try:
--> 224     yield self
    225 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1126, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
   1125         frame_number += 1
-> 1126 writer.grab_frame(**savefig_kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:352, in MovieWriter.grab_frame(self, **savefig_kwargs)
    351 # Save the figure data to the sink, using the frame format and dpi.
--> 352 self.fig.savefig(self._proc.stdin, format=self.frame_format,
    353                  dpi=self.dpi, **savefig_kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3490, in Figure.savefig(self, fname, transparent, **kwargs)
   3489         _recursively_make_axes_transparent(stack, ax)
-> 3490 self.canvas.print_figure(fname, **kwargs)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2184, in FigureCanvasBase.print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)
   2183     with cbook._setattr_cm(self.figure, dpi=dpi):
-> 2184         result = print_method(
   2185             filename,
   2186             facecolor=facecolor,
   2187             edgecolor=edgecolor,
   2188             orientation=orientation,
   2189             bbox_inches_restore=_bbox_inches_restore,
   2190             **kwargs)
   2191 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backend_bases.py:2040, in FigureCanvasBase._switch_canvas_and_return_print_method.<locals>.<lambda>(*args, **kwargs)
   2039     skip = optional_kws - {*inspect.signature(meth).parameters}
-> 2040     print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
   2041         *args, **{k: v for k, v in kwargs.items() if k not in skip}))
   2042 else:  # Let third-parties do as they see fit.

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:417, in FigureCanvasAgg.print_raw(self, filename_or_obj, metadata)
    416     raise ValueError("metadata not supported for raw/rgba")
--> 417 FigureCanvasAgg.draw(self)
    418 renderer = self.get_renderer()

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/backends/backend_agg.py:382, in FigureCanvasAgg.draw(self)
    380 with (self.toolbar._wait_cursor_for_draw_cm() if self.toolbar
    381       else nullcontext()):
--> 382     self.figure.draw(self.renderer)
    383     # A GUI class may be need to update a window using this draw, so
    384     # don't forget to call the superclass.

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:94, in _finalize_rasterization.<locals>.draw_wrapper(artist, renderer, *args, **kwargs)
     92 @wraps(draw)
     93 def draw_wrapper(artist, renderer, *args, **kwargs):
---> 94     result = draw(artist, renderer, *args, **kwargs)
     95     if renderer._rasterizing:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/figure.py:3257, in Figure.draw(self, renderer)
   3256 self.patch.draw(renderer)
-> 3257 mimage._draw_list_compositing_images(
   3258     renderer, self, artists, self.suppressComposite)
   3260 renderer.close_group('figure')

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    133     for a in artists:
--> 134         a.draw(renderer)
    135 else:
    136     # Composite any adjacent images together

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/axes/_base.py:3210, in _AxesBase.draw(self, renderer)
   3208     _draw_rasterized(self.get_figure(root=True), artists_rasterized, renderer)
-> 3210 mimage._draw_list_compositing_images(
   3211     renderer, self, artists, self.get_figure(root=True).suppressComposite)
   3213 renderer.close_group('axes')

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:134, in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
    133     for a in artists:
--> 134         a.draw(renderer)
    135 else:
    136     # Composite any adjacent images together

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/artist.py:71, in allow_rasterization.<locals>.draw_wrapper(artist, renderer)
     69         renderer.start_filter()
---> 71     return draw(artist, renderer)
     72 finally:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:609, in _ImageBase.draw(self, renderer)
    608 else:
--> 609     im, l, b, trans = self.make_image(
    610         renderer, renderer.get_image_magnification())
    611     if im is not None:

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:912, in AxesImage.make_image(self, renderer, magnification, unsampled)
    910 clip = ((self.get_clip_box() or self.axes.bbox) if self.get_clip_on()
    911         else self.get_figure(root=True).bbox)
--> 912 return self._make_image(self._A, bbox, transformed_bbox, clip,
    913                         magnification, unsampled=unsampled)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/image.py:511, in _ImageBase._make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
    509         output_alpha = _resample(  # resample alpha channel
    510             self, A[..., 3], out_shape, t)
--> 511     output = _resample(  # resample rgb channels
    512         self, _rgb_to_rgba(A[..., :3]), out_shape, t)
    513 elif np.ndim(alpha) > 0:  # Array alpha
    514     # user-specified array alpha overrides the existing alpha channel

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py:600, in IPKernelApp.sigint_handler(self, *args)
    599 elif self.kernel.shell_is_blocking:
--> 600     raise KeyboardInterrupt

KeyboardInterrupt: 

During handling of the above exception, another exception occurred:

CalledProcessError                        Traceback (most recent call last)
Cell In[5], line 27
     25 ims = ims + ims[::-1]
     26 ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
---> 27 video_html = ani.to_html5_video()
     28 video_html = re.sub(r"<video ", '<video width="800" ', video_html)
     29 HTML(video_html)

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1306, in Animation.to_html5_video(self, embed_limit)
   1302 Writer = writers[mpl.rcParams['animation.writer']]
   1303 writer = Writer(codec='h264',
   1304                 bitrate=mpl.rcParams['animation.bitrate'],
   1305                 fps=1000. / self._interval)
-> 1306 self.save(str(path), writer=writer)
   1307 # Now open and base64 encode.
   1308 vid64 = base64.encodebytes(path.read_bytes())

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:1098, in Animation.save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs, progress_callback)
   1093     return a * np.array([r, g, b]) + 1 - a
   1095 # canvas._is_saving = True makes the draw_event animation-starting
   1096 # callback a no-op; canvas.manager = None prevents resizing the GUI
   1097 # widget (both are likewise done in savefig()).
-> 1098 with (writer.saving(self._fig, filename, dpi),
   1099       cbook._setattr_cm(self._fig.canvas, _is_saving=True, manager=None)):
   1100     if not writer._supports_transparency():
   1101         facecolor = savefig_kwargs.get('facecolor',
   1102                                        mpl.rcParams['savefig.facecolor'])

File ~/.local/share/uv/python/cpython-3.10.17-linux-x86_64-gnu/lib/python3.10/contextlib.py:153, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
    151     value = typ()
    152 try:
--> 153     self.gen.throw(typ, value, traceback)
    154 except StopIteration as exc:
    155     # Suppress StopIteration *unless* it's the same exception that
    156     # was passed to throw().  This prevents a StopIteration
    157     # raised inside the "with" statement from being suppressed.
    158     return exc is not value

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:226, in AbstractMovieWriter.saving(self, fig, outfile, dpi, *args, **kwargs)
    224     yield self
    225 finally:
--> 226     self.finish()

File ~/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/matplotlib/animation.py:341, in MovieWriter.finish(self)
    337     _log.log(
    338         logging.WARNING if self._proc.returncode else logging.DEBUG,
    339         "MovieWriter stderr:\n%s", err)
    340 if self._proc.returncode:
--> 341     raise subprocess.CalledProcessError(
    342         self._proc.returncode, self._proc.args, out, err)

CalledProcessError: Command '['ffmpeg', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', '1200x600', '-pix_fmt', 'rgba', '-framerate', '5.0', '-loglevel', 'error', '-i', 'pipe:', '-vcodec', 'h264', '-pix_fmt', 'yuv420p', '-y', '/tmp/tmpgvsdbg8s/temp.m4v']' returned non-zero exit status 123.
../_images/897508d9f92d81264747f500785362e3cacddb19bfc4fef9b896d5d65b1cb86a.png

Testing data

# Same animation as for the training data, now for the test volume:
# raw data on the left, instance labels (folded into 256 colors) on the right.
fig, (raw_ax, label_ax) = plt.subplots(1, 2, figsize=(12, 6))

frames = []
for index, (raw_slice, label_slice) in enumerate(
    zip(test_raw.data, test_labels.data)
):
    first_frame = index == 0
    # every frame after the first one is flagged as animated
    extra = {} if first_frame else {"animated": True}
    raw_im = raw_ax.imshow(raw_slice, **extra)
    label_im = label_ax.imshow(
        label_slice % 256,
        cmap=label_cmap,
        vmin=0,
        vmax=255,
        interpolation="none",
        **extra,
    )
    if first_frame:
        raw_ax.set_title("Raw Test Data")
        label_ax.set_title("Test Labels")
    frames.append([raw_im, label_im])

# play the stack forwards and then backwards
frames = [*frames, *reversed(frames)]
ani = animation.ArtistAnimation(fig, frames, blit=True, repeat_delay=1000)
# force a fixed display width on the generated <video> tag
video_html = re.sub(r"<video ", '<video width="800" ', ani.to_html5_video())
HTML(video_html)
../_images/ecf72298ccc9332cc60462f48684a6f51dd76c9ae5818f489e01635158443108.png

DaCapo

Now that we have some data, let's look at how we can use DaCapo to interface with it for some common ML use cases.

Data Split

We always want to be explicit when we define our data split for training and validation so that we are aware what data is being used for training and validation.

from dacapo_toolbox.datasplits import SimpleDataSplitConfig
# Point the data split at the zarr container. As the printed output shows,
# its "train" group becomes the training dataset and its "test" group the
# validation dataset.
datasplit = SimpleDataSplitConfig(path="cremi.zarr", name="cremi")
print(f"Train datasets: {datasplit.train}")
print(f"Validation datasets: {datasplit.validate}")
Train datasets: [SimpleDataset(name='train', path=PosixPath('cremi.zarr/train'), weight=1, raw_name='raw', gt_name='labels', mask_name='mask')]
Validation datasets: [SimpleDataset(name='test', path=PosixPath('cremi.zarr/test'), weight=1, raw_name='raw', gt_name='labels', mask_name='mask')]

Augmentation

We almost always want to use rotations when training on EM data. This is because the structures we care about rarely have strict orientations relative to the zyx axes. However, because we usually have some axial anisotropy in our data, we want to be careful about how we apply these rotations.

from dacapo_toolbox.trainers import GunpowderTrainerConfig
from dacapo_toolbox.trainers.gp_augments import ElasticAugmentConfig

# Elastic deformation / rotation augment. The z entries of the control point
# settings are smaller than the y/x entries to account for the anisotropy.
elastic_augment = ElasticAugmentConfig(
    control_point_spacing=(2, 20, 20),
    control_point_displacement_sigma=(2, 20, 20),
    rotation_interval=(0, 3.14),
    subsample=4,
    uniform_3d_rotation=False,  # rotate only in 2D
    augmentation_probability=0.5,
)

# build the trainer config around that single augment
trainer = GunpowderTrainerConfig(name="rotations", augments=[elastic_augment])
/home/runner/work/dacapo-toolbox/dacapo-toolbox/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Simple Training loop

The Trainer is only useful when combined with some data, but now that we have defined some data via the DataSplitConfig and the pipeline via the TrainerConfig, we can visualize a batch:

import torch

# Number of z sections per sample and number of samples per batch.
z_slices = 13
batch_size = 3

# Input and output use the same patch shape here since no model is involved yet.
patch_shape = (z_slices, 128, 128)
torch_dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=patch_shape,
    output_shape=patch_shape,
)

# num_workers=0 keeps batch generation in the main process
dataloader = torch.utils.data.DataLoader(
    torch_dataset,
    batch_size=batch_size,
    num_workers=0,
)


# Fetch a single batch and animate it through z: one row per sample,
# raw data in the left column and labels (folded into 256 colors) on the right.
batch = next(iter(dataloader))
fig, axes = plt.subplots(batch_size, 2, figsize=(12, 18))

frames = []
for z_index in range(z_slices):
    first_frame = z_index == 0
    # every frame after the first one is flagged as animated
    extra = {} if first_frame else {"animated": True}
    frame = []
    for sample in range(batch_size):
        raw_ax, label_ax = axes[sample]
        raw_im = raw_ax.imshow(batch["raw"][sample, 0, z_index].numpy(), **extra)
        label_im = label_ax.imshow(
            batch["gt"][sample, z_index].numpy() % 256,
            cmap=label_cmap,
            vmin=0,
            vmax=255,
            interpolation="none",
            **extra,
        )
        # only the top row carries column titles
        if first_frame and sample == 0:
            raw_ax.set_title("Sample Raw")
            label_ax.set_title("Sample Labels")
        frame += [raw_im, label_im]
    frames.append(frame)

# play the stack forwards and then backwards
frames = [*frames, *reversed(frames)]
ani = animation.ArtistAnimation(fig, frames, blit=True, repeat_delay=1000)
# force a fixed display width on the generated <video> tag
video_html = re.sub(r"<video ", '<video width="800" ', ani.to_html5_video())
HTML(video_html)
../_images/ed4254b0ebc2555cde5b9881801db887384978d40dfb4691050a63dd4da43c7b.png

Tasks

When training for instance segmentation, it is not possible to directly predict label ids, since the ids have to be unique across the full volume, which is not possible with the local context that a UNet operates on. So instead we need to transform our labels into some intermediate representation that is both easy to predict and easy to post-process. The most common method we use is a combination of affinities with optional lsds for prediction plus mutex watershed for post-processing.

Next we will define the task that encapsulates this process.

from dacapo_toolbox.tasks import AffinitiesTaskConfig

# Affinity offsets: three nearest-neighbor offsets (one per axis) followed by
# longer-range offsets of 9 and 27 along the last two axes only
# (presumably z, y, x order, matching the data's axis names).
affinity_offsets = [
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 9],
    [0, 9, 0],
    [0, 0, 27],
    [0, 27, 0],
]

affs_config = AffinitiesTaskConfig(
    name="affs",
    neighborhood=affinity_offsets,
    # lsds=True,
)

# Rebuild the dataset with the affinities task attached, so that batches also
# carry the task's "target" alongside "raw" and "gt".
patch_shape = (z_slices, 128, 128)
torch_dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=patch_shape,
    output_shape=patch_shape,
    task=affs_config,
)

# num_workers=0 keeps batch generation in the main process
dataloader = torch.utils.data.DataLoader(
    torch_dataset,
    batch_size=batch_size,
    num_workers=0,
)

# Fetch a single batch and animate it through z: one row per sample with
# raw data, labels (folded into 256 colors) and affinity targets side by side.
batch = next(iter(dataloader))
fig, axes = plt.subplots(batch_size, 3, figsize=(18, 18))

frames = []
for z_index in range(z_slices):
    first_frame = z_index == 0
    # every frame after the first one is flagged as animated
    extra = {} if first_frame else {"animated": True}
    frame = []
    for sample in range(batch_size):
        raw_ax, label_ax, affs_ax = axes[sample]
        raw_im = raw_ax.imshow(batch["raw"][sample, 0, z_index].numpy(), **extra)
        label_im = label_ax.imshow(
            batch["gt"][sample, z_index].numpy() % 256,
            cmap=label_cmap,
            vmin=0,
            vmax=255,
            interpolation="none",
            **extra,
        )
        # target channels 0, 5 and 6 are moved to the last axis and shown as RGB
        affs_rgb = batch["target"][sample, [0, 5, 6], z_index].numpy()
        affs_im = affs_ax.imshow(
            affs_rgb.transpose(1, 2, 0), interpolation="none", **extra
        )
        # only the top row carries column titles
        if first_frame and sample == 0:
            raw_ax.set_title("Sample Raw")
            label_ax.set_title("Sample Labels")
            affs_ax.set_title("Sample Affinities")
        frame += [raw_im, label_im, affs_im]
    frames.append(frame)

# play the stack forwards and then backwards
frames = [*frames, *reversed(frames)]
ani = animation.ArtistAnimation(fig, frames, blit=True, repeat_delay=1000)
# force a fixed display width on the generated <video> tag
video_html = re.sub(r"<video ", '<video width="800" ', ani.to_html5_video())
HTML(video_html)
../_images/c957669a2708262d8acec3c0f68fd7966376a00b14fb09c53734f1ed423873ff.png

Models

DaCapo lets you easily train any model you want, with a special wrapper for UNets specifically. Let's make one now.

from dacapo_toolbox.architectures import CNNectomeUNetConfig
from funlib.geometry import Coordinate, Roi

# Shape of the model input, in voxels.
input_shape = Coordinate((5, 156, 156))

# Three downsampling steps that leave the first axis untouched.
xy_downsampling = [(1, 2, 2) for _ in range(3)]
# Convolutions are flat (1, 3, 3) everywhere except the last upsampling
# level, which uses full (3, 3, 3) kernels.
down_kernels = [[(1, 3, 3), (1, 3, 3)] for _ in range(4)]
up_kernels = [[(1, 3, 3), (1, 3, 3)] for _ in range(2)] + [[(3, 3, 3), (3, 3, 3)]]

unet_config = CNNectomeUNetConfig(
    name="2.5D_UNet",
    input_shape=input_shape,
    fmaps_in=1,
    fmaps_out=32,
    num_fmaps=32,
    fmap_inc_factor=4,
    downsample_factors=xy_downsampling,
    kernel_size_down=down_kernels,
    kernel_size_up=up_kernels,
)

# Ask the config which output shape results from the chosen input shape.
output_shape = unet_config.compute_output_shape(input_shape)
print(f"Given an input of shape {input_shape} we get an out of shape {output_shape}")
Given an input of shape (5, 156, 156) we get an out of shape (1, 64, 64)

Training loop

Now we can bring everything together and train our model.

# Dataset and loader sized for the U-Net's input/output shapes. Three
# persistent worker processes each prefetch two batches.
dataset = trainer.iterable_dataset(
    datasets=datasplit.train,
    input_shape=input_shape,
    output_shape=output_shape,
    task=affs_config,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=3,
    prefetch_factor=2,
    persistent_workers=True,
)

# Instantiate the task from its config.
task = affs_config.task_type(affs_config)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# this ensures we output the appropriate number of channels, use the appropriate final activation etc.
module = task.create_model(unet_config).to(device)
loss = task.loss
optimizer = torch.optim.Adam(module.parameters(), lr=1e-4)

losses = []

print(f"Training on {device}")
for iteration, batch in enumerate(dataloader):
    # move the tensors needed for the loss onto the training device
    raw = batch["raw"].to(device)
    target = batch["target"].to(device)
    weight = batch["weight"].to(device)

    # one optimisation step
    optimizer.zero_grad()
    output = module(raw)
    loss_value = loss.compute(output, target, weight)
    loss_value.backward()
    optimizer.step()

    print(f"Loss ({iteration}): {loss_value.item():.3f}")
    losses.append(loss_value.item())

    # the dataset is iterable, so stop explicitly after iterations 0..10
    if iteration >= 10:
        break
Training on cpu
Loss (0): 0.597
Loss (1): 0.659
Loss (2): 0.836
Loss (3): 0.585
Loss (4): 0.944
Loss (5): 0.609
Loss (6): 0.599
Loss (7): 0.718
Loss (8): 0.663
Loss (9): 0.598
Loss (10): 0.592
# Plot the per-iteration training loss collected above.
loss_ax = plt.gca()
loss_ax.plot(losses)
loss_ax.set_xlabel("Iteration")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Loss Curve")
plt.show()
../_images/a50ea5128eb1a775c7001f743f01bba6f273e79f31da404e3610c5d681ee279d.png
import mwatershed as mws

# Let's predict on some validation data:
val_raw, val_gt = datasplit.validate[0].raw, datasplit.validate[0].gt

# Build a ROI that is one voxel thick in z, located at the z-center of the
# validation volume, and that covers half of the volume's extent in y and x
# starting at the volume's offset.
# We snap to grid to a multiple of the max downsampling factor of
# the unet (1, 8, 8) to ensure downsampling is always possible
roi = val_raw.roi
z_coord = Coordinate(1, 0, 0)
xy_coord = Coordinate(0, 1, 1)
center_offset = roi.center * z_coord + roi.offset * xy_coord
center_size = val_raw.voxel_size * z_coord + (roi.shape * xy_coord) // 2
center_slice = Roi(center_offset, center_size)
center_slice = center_slice.snap_to_grid(val_raw.voxel_size * Coordinate(1, 8, 8))

# Extra context (in world units) the model consumes on each side of its output.
context = (input_shape - output_shape) // 2 * val_raw.voxel_size

# Read the raw data, grown by the context on both sides
raw_input = val_raw.to_ndarray(center_slice.grow(context, context))

# Predict on the validation data (on the CPU, without tracking gradients).
device = torch.device("cpu")
module = module.to(device)
# Switch to evaluation mode for inference; the model was left in training
# mode by the training loop.
# NOTE(review): the clipping warning recorded for the plotting cell reports
# predictions in [-1.32, 2.03] rather than (0, 1); presumably the task's final
# activation is only applied in eval mode - confirm against the model wrapper.
module.eval()
with torch.no_grad():
    # add batch and channel dimensions in front of the (z, y, x) volume
    model_input = torch.from_numpy(raw_input).to(device).unsqueeze(0).unsqueeze(0)
    pred = module(model_input).cpu().numpy()
# Prepare a 1x4 figure: raw, ground-truth labels, predicted affinities,
# predicted labels.
fig, axes = plt.subplots(1, 4, figsize=(24, 8))
# Number of voxels the model consumes on each side of its output.
padding = (input_shape - output_shape) // 2

# Select three affinity channels to display as RGB: channel 0 (offset
# [0, 0, 1]) plus the two longest-range channels 5 and 6 (offsets
# [0, 0, 27] and [0, 27, 0]).
prediction = pred[0, [0, 5, 6], 0]

# Run mutex watershed on the affinity predictions.
# We subtract 0.5 to move affs from range (0, 1) to (-0.5, 0.5).
# This is because mutex only splits objects on negative edges.
pred_labels = (
    mws.agglom(pred[0].astype(np.float64) - 0.5, offsets=affs_config.neighborhood)[0]
    % 256
)

# read the ground truth labels
gt_labels = val_gt.to_ndarray(center_slice)[0] % 256

# Pad everything in y/x back to the size of the raw input so that the four
# panels line up; affinities are padded with NaN, label images with 0.
yx_pad = ((padding[1],), (padding[2],))
prediction = np.pad(
    prediction,
    ((0,), *yx_pad),
    mode="constant",
    constant_values=np.nan,
)
pred_labels = np.pad(pred_labels, yx_pad, mode="constant", constant_values=0)
gt_labels = np.pad(gt_labels, yx_pad, mode="constant", constant_values=0)

# Plot the results. The raw section shown is the one matching the output
# slice: it sits `padding[0]` sections into the input stack (index 2 for the
# 5-section input used here) instead of relying on a hard-coded index.
im_raw = axes[0].imshow(raw_input[padding[0]])
im2 = axes[1].imshow(gt_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none")
im4 = axes[2].imshow(prediction.transpose(1, 2, 0), interpolation="none")
im5 = axes[3].imshow(
    pred_labels, cmap=label_cmap, vmin=0, vmax=255, interpolation="none"
)
axes[0].set_title("Val Raw")
axes[1].set_title("Val Labels")
axes[2].set_title("Pred Affinities")
axes[3].set_title("Pred Labels")
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.3228953..2.0344982].
../_images/758ff76c45b9d6e2ea1654c5bd4b3109db762b67b76ecdeec698d1260441c6d2.png